{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Currently \"Part B: Encrypted Aggregation\" of this notebook is not working because numpy is not hooked yet.\n",
    "\n",
    "To check the status please visit https://github.com/OpenMined/PySyft/issues/2771.\n",
    "\n",
    "This message will be removed as soon as Part B is working again."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 10: Federated Learning with Encrypted Gradient Aggregation\n",
    "\n",
    "In the last few sections, we've been learning about encrypted computation by building several simple programs. In this section, we're going to return to the [Federated Learning Demo of Part 4](https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%204%20-%20Federated%20Learning%20via%20Trusted%20Aggregator.ipynb), where we had a \"trusted aggregator\" who was responsible for averaging the model updates from multiple workers.\n",
    "\n",
    "We will now use our new tools for encrypted computation to remove this trusted aggregator because it is less than ideal as it assumes that we can find someone trustworthy enough to have access to this sensitive information. This is not always the case.\n",
    "\n",
    "Thus, in this notebook, we will show how one can use SMPC to perform secure aggregation such that we don't need a \"trusted aggregator\".\n",
    "\n",
    "Authors:\n",
    "- Theo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel)\n",
    "- Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Section 1: Normal Federated Learning\n",
    "\n",
    "First, here is some code which performs classic federated learning on the Boston Housing Dataset. This section of code is broken into several sections.\n",
    "\n",
    "### Setting Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "class Parser:\n",
    "    \"\"\"Parameters for training\"\"\"\n",
    "    def __init__(self):\n",
    "        self.epochs = 10\n",
    "        self.lr = 0.001\n",
    "        self.test_batch_size = 8\n",
    "        self.batch_size = 8\n",
    "        self.log_interval = 10\n",
    "        self.seed = 1\n",
    "    \n",
    "args = Parser()\n",
    "\n",
    "torch.manual_seed(args.seed)\n",
    "kwargs = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../data/BostonHousing/boston_housing.pickle','rb') as f:\n",
    "    ((X, y), (X_test, y_test)) = pickle.load(f)\n",
    "\n",
    "X = torch.from_numpy(X).float()\n",
    "y = torch.from_numpy(y).float()\n",
    "X_test = torch.from_numpy(X_test).float()\n",
    "y_test = torch.from_numpy(y_test).float()\n",
    "# preprocessing\n",
    "mean = X.mean(0, keepdim=True)\n",
    "dev = X.std(0, keepdim=True)\n",
    "mean[:, 3] = 0. # the feature at column 3 is binary,\n",
    "dev[:, 3] = 1.  # so we don't standardize it\n",
    "X = (X - mean) / dev\n",
    "X_test = (X_test - mean) / dev\n",
    "train = TensorDataset(X, y)\n",
    "test = TensorDataset(X_test, y_test)\n",
    "train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)\n",
    "test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neural Network Structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(13, 32)\n",
    "        self.fc2 = nn.Linear(32, 24)\n",
    "        self.fc3 = nn.Linear(24, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 13)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "model = Net()\n",
    "optimizer = optim.SGD(model.parameters(), lr=args.lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hooking PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy\n",
    "\n",
    "hook = sy.TorchHook(torch)\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n",
    "james = sy.VirtualWorker(hook, id=\"james\")\n",
    "\n",
    "compute_nodes = [bob, alice]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Send data to the workers** <br>\n",
    "Usually they would already have it, this is just for demo purposes that we send it manually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_distributed_dataset = []\n",
    "\n",
    "for batch_idx, (data,target) in enumerate(train_loader):\n",
    "    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    train_distributed_dataset.append((data, target))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data,target) in enumerate(train_distributed_dataset):\n",
    "        worker = data.location\n",
    "        model.send(worker)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        # update the model\n",
    "        pred = model(data)\n",
    "        loss = F.mse_loss(pred.view(-1), target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.get()\n",
    "            \n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            loss = loss.get()\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * data.shape[0], len(train_loader),\n",
    "                       100. * batch_idx / len(train_loader), loss.item()))\n",
    "        \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    for data, target in test_loader:\n",
    "        output = model(data)\n",
    "        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # sum up batch loss\n",
    "        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
    "        \n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}\\n'.format(test_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(epoch)\n",
    "\n",
    "    \n",
    "total_time = time.time() - t\n",
    "print('Total', round(total_time, 2), 's')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculating Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Section 2: Adding Encrypted Aggregation\n",
    "\n",
    "Now we're going to slightly modify this example to aggregate gradients using encryption. The main piece that's different is really 1 or 2 lines of code in the `train()` function, which we'll point out. For the moment, let's re-process our data and initialize a model for bob and alice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "remote_dataset = (list(),list())\n",
    "\n",
    "train_distributed_dataset = []\n",
    "\n",
    "for batch_idx, (data,target) in enumerate(train_loader):\n",
    "    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    remote_dataset[batch_idx % len(compute_nodes)].append((data, target))\n",
    "\n",
    "def update(data, target, model, optimizer):\n",
    "    model.send(data.location)\n",
    "    optimizer.zero_grad()\n",
    "    pred = model(data)\n",
    "    loss = F.mse_loss(pred.view(-1), target)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return model\n",
    "\n",
    "bobs_model = Net()\n",
    "alices_model = Net()\n",
    "\n",
    "bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)\n",
    "alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)\n",
    "\n",
    "models = [bobs_model, alices_model]\n",
    "params = [list(bobs_model.parameters()), list(alices_model.parameters())]\n",
    "optimizers = [bobs_optimizer, alices_optimizer]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building our Training Logic\n",
    "\n",
    "The only **real** difference is inside of this train method. Let's walk through it step-by-step.\n",
    "\n",
    "### Part A: Train:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this is selecting which batch to train on\n",
    "data_index = 0\n",
    "# update remote models\n",
    "# we could iterate this multiple times before proceeding, but we're only iterating once per worker here\n",
    "for remote_index in range(len(compute_nodes)):\n",
    "    data, target = remote_dataset[remote_index][data_index]\n",
    "    models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Part B: Encrypted Aggregation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a list where we'll deposit our encrypted model average\n",
    "new_params = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# iterate through each parameter\n",
    "for param_i in range(len(params[0])):\n",
    "\n",
    "    # for each worker\n",
    "    spdz_params = list()\n",
    "    for remote_index in range(len(compute_nodes)):\n",
    "        \n",
    "        # select the identical parameter from each worker and copy it\n",
    "        copy_of_parameter = params[remote_index][param_i].copy()\n",
    "        \n",
    "        # since SMPC can only work with integers (not floats), we need\n",
    "        # to use Integers to store decimal information. In other words,\n",
    "        # we need to use \"Fixed Precision\" encoding.\n",
    "        fixed_precision_param = copy_of_parameter.fix_precision()\n",
    "        \n",
    "        # now we encrypt it on the remote machine. Note that \n",
    "        # fixed_precision_param is ALREADY a pointer. Thus, when\n",
    "        # we call share, it actually encrypts the data that the\n",
    "        # data is pointing TO. This returns a POINTER to the \n",
    "        # MPC secret shared object, which we need to fetch.\n",
    "        encrypted_param = fixed_precision_param.share(bob, alice, crypto_provider=james)\n",
    "        \n",
    "        # now we fetch the pointer to the MPC shared value\n",
    "        param = encrypted_param.get()\n",
    "        \n",
    "        # save the parameter so we can average it with the same parameter\n",
    "        # from the other workers\n",
    "        spdz_params.append(param)\n",
    "\n",
    "    # average params from multiple workers, fetch them to the local machine\n",
    "    # decrypt and decode (from fixed precision) back into a floating point number\n",
    "    new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2\n",
    "    \n",
    "    # save the new averaged parameter\n",
    "    new_params.append(new_param)"
   ]
  },
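  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comments above rely on two ideas: fixed-precision encoding and additive secret sharing. The block below is a minimal plain-Python sketch of both, meant only as an illustration and not as PySyft's implementation. The helper names (`encode`, `decode`, `share`, `reconstruct`), the ring size `Q`, and the precision `BASE ** PRECISION` are assumptions chosen for this sketch.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "Q = 2 ** 62               # assumed ring size for this sketch\n",
    "BASE, PRECISION = 10, 3   # assumed: keep 3 decimal digits\n",
    "\n",
    "def encode(x):\n",
    "    # fixed precision: map a float to an integer in [0, Q)\n",
    "    return int(round(x * BASE ** PRECISION)) % Q\n",
    "\n",
    "def decode(v):\n",
    "    # the upper half of the ring represents negative numbers\n",
    "    if v > Q // 2:\n",
    "        v -= Q\n",
    "    return v / BASE ** PRECISION\n",
    "\n",
    "def share(secret, n_parties=2):\n",
    "    # split an integer into n additive shares that sum to it mod Q\n",
    "    shares = [random.randrange(Q) for _ in range(n_parties - 1)]\n",
    "    shares.append((secret - sum(shares)) % Q)\n",
    "    return shares\n",
    "\n",
    "def reconstruct(shares):\n",
    "    # adding all shares mod Q recovers the encoded secret\n",
    "    return sum(shares) % Q\n",
    "\n",
    "# hypothetical values: one weight held by each worker\n",
    "bob_weight, alice_weight = 0.25, -0.75\n",
    "\n",
    "bob_shares = share(encode(bob_weight))\n",
    "alice_shares = share(encode(alice_weight))\n",
    "\n",
    "# each party adds the two shares it holds, without seeing the secrets\n",
    "summed_shares = [(b + a) % Q for b, a in zip(bob_shares, alice_shares)]\n",
    "\n",
    "# only the sum is reconstructed and decoded, then divided by the number of workers\n",
    "average = decode(reconstruct(summed_shares)) / 2\n",
    "print(average)\n",
    "```\n",
    "\n",
    "In this sketch each individual share is a uniformly random number, yet the shares of the sum reconstruct to the sum of the secrets. This is the property that lets the average be computed without a trusted aggregator."
   ]
  },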
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Part C: Cleanup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    for model in params:\n",
    "        for param in model:\n",
    "            param *= 0\n",
    "\n",
    "    for model in models:\n",
    "        model.get()\n",
    "\n",
    "    for remote_index in range(len(compute_nodes)):\n",
    "        for param_index in range(len(params[remote_index])):\n",
    "            params[remote_index][param_index].set_(new_params[param_index])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's put it all Together!!\n",
    "\n",
    "And now that we know each step, we can put it all together into one training loop!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    for data_index in range(len(remote_dataset[0])-1):\n",
    "        # update remote models\n",
    "        for remote_index in range(len(compute_nodes)):\n",
    "            data, target = remote_dataset[remote_index][data_index]\n",
    "            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])\n",
    "\n",
    "        # encrypted aggregation\n",
    "        new_params = list()\n",
    "        for param_i in range(len(params[0])):\n",
    "            spdz_params = list()\n",
    "            for remote_index in range(len(compute_nodes)):\n",
    "                spdz_params.append(params[remote_index][param_i].copy().fix_precision().share(bob, alice, crypto_provider=james).get())\n",
    "\n",
    "            new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2\n",
    "            new_params.append(new_param)\n",
    "\n",
    "        # cleanup\n",
    "        with torch.no_grad():\n",
    "            for model in params:\n",
    "                for param in model:\n",
    "                    param *= 0\n",
    "\n",
    "            for model in models:\n",
    "                model.get()\n",
    "\n",
    "            for remote_index in range(len(compute_nodes)):\n",
    "                for param_index in range(len(params[remote_index])):\n",
    "                    params[remote_index][param_index].set_(new_params[param_index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    models[0].eval()\n",
    "    test_loss = 0\n",
    "    for data, target in test_loader:\n",
    "        output = models[0](data)\n",
    "        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # sum up batch loss\n",
    "        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
    "        \n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('Test set: Average loss: {:.4f}\\n'.format(test_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "\n",
    "for epoch in range(args.epochs):\n",
    "    print(f\"Epoch {epoch + 1}\")\n",
    "    train(epoch)\n",
    "    test()\n",
    "\n",
    "    \n",
    "total_time = time.time() - t\n",
    "print('Total', round(total_time, 2), 's')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on Github\n",
    "\n",
    "The easiest way to help our community is just by starring the Repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to PySyft Github Issues page and filter for \"Projects\". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more \"one off\" mini-projects by searching for github issues marked \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
